{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OFA pretraining data preparation\n",
    "\n",
    "Explore OFA's example pretraining TSVs (`pretrain_example/`) and build our own TSVs under `pretrain_ours/` (vision-language, detection, text and ImageNet image files)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import base64\n",
    "import codecs\n",
    "import csv\n",
    "import json\n",
    "import os\n",
    "import sys\n",
    "from io import BytesIO\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Optional dependencies from earlier experiments (not used below):\n",
    "# import clip\n",
    "# import sng_parser"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explore\n",
    "\n",
    "Load one of OFA's example pretraining TSVs (`pretrain_example/`) into `data_example` and inspect its rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [00:00, 14325.30it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fields can be very long, so lift the csv module's field-size cap.\n",
    "csv.field_size_limit(sys.maxsize)\n",
    "\n",
    "# Uncomment one (path_data, selected_cols) pair to inspect that example file.\n",
    "# path_data = '/data/mshukor/data/ofa/pretrain_example/vision_language_examples.tsv'\n",
    "# selected_cols = '0,1,2,3,4,5,6,7'\n",
    "\n",
    "# path_data = '/data/mshukor/data/ofa/pretrain_example/detection_examples.tsv'\n",
    "# selected_cols = '0,1,2'\n",
    "\n",
    "# path_data = '/data/mshukor/data/ofa/pretrain_example/image_examples.tsv'\n",
    "# selected_cols = '0,1,2'\n",
    "\n",
    "path_data = '/data/mshukor/data/ofa/pretrain_example/text_examples.tsv'\n",
    "selected_cols = '0,1'\n",
    "\n",
    "selected_col_ids = list(map(int, selected_cols.split(',')))\n",
    "\n",
    "# Keep only the selected columns of every tab-separated row.\n",
    "data_example = []\n",
    "with open(path_data) as tsv_handle:\n",
    "    for line in tqdm(csv.reader(tsv_handle, delimiter='\\t')):\n",
    "        d = [line[i] for i in selected_col_ids]\n",
    "        data_example.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Selected columns of the last parsed row (instead of the loop variable `line`\n",
    "# leaked from the loading cell).\n",
    "data_example[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one row loaded above (`data_` is only created much further down).\n",
    "data_example[6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct task types (last column) and dataset names (second-to-last column).\n",
    "tasks = {d[-1] for d in data}\n",
    "datasets = {d[-2] for d in data}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int(data[10][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.tokenize import word_tokenize\n",
    "from nltk.tokenize.treebank import TreebankWordDetokenizer\n",
    "\n",
    "# Whitespace-split word count, character count and NLTK token count of one text.\n",
    "text = data[10][1]\n",
    "print(len(text.split(' ')))\n",
    "print(len(text))\n",
    "len(word_tokenize(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round trip: tokenize then detokenize, to compare against the original text.\n",
    "TreebankWordDetokenizer().detokenize(word_tokenize(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of one dataset (column -2 holds the dataset name); only the first few are printed.\n",
    "key = 'refcoco_train'\n",
    "index = -2\n",
    "max_print = 10  # set to None to print every matching row\n",
    "\n",
    "matches = [d for d in data if d[index] == key]\n",
    "for d in matches[:max_print]:\n",
    "    print(d[2:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box formatting options: split of the last printed row's box field vs. str(list)\n",
    "# vs. a fixed 2-decimal format. All three values are displayed.\n",
    "box = [287.0, 127.0, 340.0, 162.0]\n",
    "d[4].split(','), str(box), '{:.2f},{:.2f},{:.2f},{:.2f}'.format(*box)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(data))\n",
    "data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_captions_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/all_captions.txt'\n",
    "all_objects_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/object.txt'\n",
    "\n",
    "\n",
    "def read_non_empty_lines(path):\n",
    "    \"\"\"Return the stripped, non-empty lines of a text file (the file is closed afterwards).\"\"\"\n",
    "    with open(path) as f:\n",
    "        return [row.strip() for row in f if row.strip()]\n",
    "\n",
    "\n",
    "all_object_list = read_non_empty_lines(all_objects_path)\n",
    "all_caption_list = read_non_empty_lines(all_captions_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(all_object_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(all_caption_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_object_list[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_caption_list[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/type2ans.json'\n",
    "with open(json_path) as f:\n",
    "    type2ans = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "type2ans.keys()\n",
    "# type2ans['what color is the']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our data\n",
    "\n",
    "Load one of our generated TSVs (`pretrain_ours/`) into `data`, then inspect and clean it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "181767it [00:02, 70482.72it/s]\n"
     ]
    }
   ],
   "source": [
    "# Uncomment one (path_data, selected_cols) pair to load that file into `data`.\n",
    "# path_data = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
    "# selected_cols = '0,1'\n",
    "\n",
    "path_data = '/data/mshukor/data/ofa/pretrain_ours/detection_mini.tsv'\n",
    "selected_cols = '0,1,2'\n",
    "\n",
    "# path_data = '/data/mshukor/data/ofa/pretrain_ours/vision_language_mini.tsv'\n",
    "# selected_cols = '0,1,2,3,4,5,6,7'\n",
    "\n",
    "# path_data = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
    "# selected_cols = '0,1,2'\n",
    "\n",
    "selected_col_ids = list(map(int, selected_cols.split(',')))\n",
    "\n",
    "# Keep only the selected columns of every tab-separated row.\n",
    "data = []\n",
    "with open(path_data) as tsv_handle:\n",
    "    for line in tqdm(csv.reader(tsv_handle, delimiter='\\t')):\n",
    "        d = [line[i] for i in selected_col_ids]\n",
    "        data.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draft (disabled): clean only the 4 box coordinates of each detection label with\n",
    "# `remove_special` and rebuild column 2 as a '&&'-joined label string in `new_data`.\n",
    "# new_data = []\n",
    "# for d in data:\n",
    "#     label_list = d[2].strip().split('&&')\n",
    "#     new_label_list = []\n",
    "#     for label in label_list:\n",
    "#         lab = label.strip().split(',', 5)[:4] # x0, y0, x1, y1, cat_id, cat\n",
    "        \n",
    "#         if any([\"&\" in l for l in lab]):\n",
    "#             lab = [remove_special(l) for l in lab]\n",
    "            \n",
    "#             print(lab)\n",
    "#         lab_ = lab + label.strip().split(',', 5)[4:]\n",
    "#         lab_ = ','.join(lab_)\n",
    "#         new_label_list.append(lab_)\n",
    "#     new_label_list = ['&&'.join(new_label_list)]\n",
    "#     new_data.append(d[:2]+new_label_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['&40.000', '155.000', '44.000', '164.000']\n"
     ]
    }
   ],
   "source": [
    "# Detection labels are '&&'-joined 'x0,y0,x1,y1,cat_id,cat' entries; report entries\n",
    "# whose four box coordinates contain a stray '&'.\n",
    "for d in data:\n",
    "    for label in d[2].strip().split('&&'):\n",
    "        coords = label.strip().split(',', 5)[:4]  # x0, y0, x1, y1\n",
    "        if any('&' in c for c in coords):\n",
    "            print(coords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['1.020,279.960,534.110,480.000,67,dining table',\n",
       " '90.670,271.490,262.510,480.000,62,chair',\n",
       " '233.290,270.450,403.610,473.810,62,chair',\n",
       " '367.820,264.270,506.970,480.000,62,chair',\n",
       " '476.760,261.030,596.490,462.740,62,chair',\n",
       " '263.030,174.370,417.670,299.400,64,potted plant',\n",
       " '539.330,290.160,640.000,469.210,62,chair',\n",
       " '10.790,260.030,125.120,384.070,62,chair',\n",
       " '560.800,413.950,639.090,479.200,67,dining table',\n",
       " '20.540,376.760,103.780,431.890,62,chair',\n",
       " '1.080,373.210,32.360,480.000,62,chair',\n",
       " '298.200,235.170,381.210,269.250,86,vase',\n",
       " '152.170,256.670,230.580,285.780,62,chair',\n",
       " '364.400,256.570,417.060,283.210,62,chair',\n",
       " '296.780,277.790,329.260,289.780,84,book',\n",
       " '292.800,289.310,314.210,300.650,84,book',\n",
       " '285.800,257.460,299.770,273.600,62,chair']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Labels of the first image, one '&&'-separated entry per box.\n",
    "data[0][2].strip().split('&&')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_special(input_string):\n",
    "    \"\"\"Return `input_string` keeping only alphanumeric characters and plain spaces.\"\"\"\n",
    "    # str.join over a generator instead of growing a new string per character.\n",
    "    return ''.join(ch for ch in input_string if ch == ' ' or ch.isalnum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strip special characters from column 2 of every row (the caption for caption data).\n",
    "# NOTE: on detection data column 2 holds the '&&'-joined box labels, and this would\n",
    "# also remove their ',', '.' and '&' separators.\n",
    "for d in tqdm(data):\n",
    "    caption = d[2]\n",
    "    d[2] = remove_special(caption)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['0',\n",
       " 'coco/train2014/COCO_train2014_000000057870.jpg',\n",
       " '1.020,279.960,534.110,480.000,67,dining table&&90.670,271.490,262.510,480.000,62,chair&&233.290,270.450,403.610,473.810,62,chair&&367.820,264.270,506.970,480.000,62,chair&&476.760,261.030,596.490,462.740,62,chair&&263.030,174.370,417.670,299.400,64,potted plant&&539.330,290.160,640.000,469.210,62,chair&&10.790,260.030,125.120,384.070,62,chair&&560.800,413.950,639.090,479.200,67,dining table&&20.540,376.760,103.780,431.890,62,chair&&1.080,373.210,32.360,480.000,62,chair&&298.200,235.170,381.210,269.250,86,vase&&152.170,256.670,230.580,285.780,62,chair&&364.400,256.570,417.060,283.210,62,chair&&296.780,277.790,329.260,289.780,84,book&&292.800,289.310,314.210,300.650,84,book&&285.800,257.460,299.770,273.600,62,chair']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█| 181767/181767 [00:00<00:00,\n"
     ]
    }
   ],
   "source": [
    "# Drop double quotes from column 2 (csv.writer would otherwise quote such fields).\n",
    "for d in tqdm(data):\n",
    "    d[2] = d[2].replace('\\\"', '')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse path_data with plain readline/split('\\t') (no csv quoting rules), stopping at\n",
    "# EOF or at the first line with fewer than 2 columns (e.g. a blank line), which is\n",
    "# not kept.\n",
    "data_ = []\n",
    "with open(path_data) as file:\n",
    "    for raw_line in tqdm(iter(file.readline, '')):\n",
    "        column_l = raw_line.rstrip('\\n').split('\\t')\n",
    "        if len(column_l) < 2:\n",
    "            break\n",
    "        data_.append(column_l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "5593207it [00:03, 1463300.52it/s]\n"
     ]
    }
   ],
   "source": [
    "# Raw lines (str, each ending in '\\n') of the vision-language TSV.\n",
    "vl_path = '/data/mshukor/data/ofa/pretrain_ours/vision_language_mini_.tsv'\n",
    "with open(vl_path) as fp:\n",
    "    data_example = list(tqdm(fp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2796604\tcc3m/train/8/2d0d96e4ecb8e2e959a3bf10d59b9d05ac114aea.jpg\tthe residential development under construction in district\t\t\t\tcc3m\tcaption\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(data_example[2796604])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# seek() expects a byte offset, not a line index: seeking to 2796604 lands in the\n",
    "# middle of some earlier line. Accumulate the byte length of the preceding lines\n",
    "# instead, then read the line at that offset in binary mode and compare it with the\n",
    "# text-mode line.\n",
    "idx = 2796604\n",
    "with open(vl_path, 'rb') as fb:\n",
    "    offset = sum(len(fb.readline()) for _ in range(idx))\n",
    "    fb.seek(offset)\n",
    "    line_at_offset = fb.readline().decode('utf-8')\n",
    "print(repr(line_at_offset))\n",
    "print(repr(data_example[idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2317 2317\n",
      "2510 2514 2510\n"
     ]
    }
   ],
   "source": [
    "# Character vs. UTF-8 byte length of the same raw line (they differ when it contains\n",
    "# non-ASCII text). NOTE: expects `data_` to hold raw str lines, as produced by the\n",
    "# text_mini.tsv reading cell below, not the split columns built above.\n",
    "print(len(data_example[10]), len(data_example[10].encode('utf-8')))\n",
    "print(len(data_[10]), len(data_[10].encode('utf-8')), len(data_[10].encode('utf-8').decode('utf-8')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data_[10].encode('utf-8'))\n",
    "print(data_[10])\n",
    "\n",
    "print(data_example[10].encode('utf-8'))\n",
    "print(data_example[10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "6458670it [01:45, 61129.36it/s] \n"
     ]
    }
   ],
   "source": [
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
    "\n",
    "# Raw lines (str, each ending in '\\n') of the current text TSV.\n",
    "with open(output_path) as fp:\n",
    "    data_ = list(tqdm(fp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
    "\n",
    "start_id = 0\n",
    "\n",
    "# One (id, text) row per entry of `data`, with consecutive ids from start_id.\n",
    "# lineterminator='\\n': csv.writer defaults to '\\r\\n', which leaves a trailing '\\r'\n",
    "# on the last column for readers that only strip '\\n' (readline().rstrip('\\n')).\n",
    "with open(output_path, 'w', newline='') as f_output:\n",
    "    csv_output = csv.writer(f_output, delimiter='\\t', lineterminator='\\n')\n",
    "    for row_id, t in tqdm(enumerate(data, start=start_id)):\n",
    "        csv_output.writerow([row_id, t[1]])\n",
    "\n",
    "start_id += len(data)  # next unused id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the cleaned rows back. NOTE: this is the same file `data` was loaded from\n",
    "# (path_data in the \"Our data\" cell), so the original file is overwritten.\n",
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/detection_mini.tsv'\n",
    "\n",
    "# lineterminator='\\n': csv.writer defaults to '\\r\\n' line endings.\n",
    "with open(output_path, 'w', newline='') as f_output:\n",
    "    csv_output = csv.writer(f_output, delimiter='\\t', lineterminator='\\n')\n",
    "    csv_output.writerows(tqdm(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create data tsv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_img_to_str(file_name):\n",
    "    \"\"\"Return the image at `file_name`, re-saved in its own format, as a base64 str.\n",
    "\n",
    "    The image file is closed before returning.\n",
    "    \"\"\"\n",
    "    with Image.open(file_name) as img:\n",
    "        img_buffer = BytesIO()\n",
    "        img.save(img_buffer, format=img.format)\n",
    "    return base64.b64encode(img_buffer.getvalue()).decode('utf-8')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create VL tsv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Caption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_data_path = '/data/mshukor/data/our_albef_data/json_pretrain/sbu.json'\n",
    "with open(original_data_path) as f:\n",
    "    original_data = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preprocess.utils import get_tsv_data_from_jsons\n",
    "\n",
    "# Source annotation files; the refcoco+ entry is a (refs pickle, instances json) pair.\n",
    "dataset_paths = [\n",
    "    '/data/mshukor/data/our_albef_data/json_pretrain/coco_karp.json',\n",
    "    '/data/mshukor/data/our_albef_data/json_pretrain/vg_albef.json',\n",
    "    '/data/mshukor/data/our_albef_data/json_pretrain/sbu.json',\n",
    "    '/data/mshukor/data/our_albef_data/json_pretrain/cc3m.json',\n",
    "    ['/data/mshukor/data/refcoco/refcoco+/refs(unc).p', '/data/mshukor/data/refcoco/refcoco+/instances.json'],\n",
    "    '/data/mshukor/data/our_albef_data/data/vqa_train.json',\n",
    "]\n",
    "# Task type of each entry in dataset_paths (same order).\n",
    "task_types = ['caption', 'caption', 'caption', 'caption', 'visual_grounding', 'qa']\n",
    "assert len(task_types) == len(dataset_paths)\n",
    "\n",
    "start_id = 0\n",
    "tsvs = get_tsv_data_from_jsons(dataset_paths, start_id, task_types, convert_images=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(tsvs))\n",
    "tsvs[-1000000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/vision_language_mini.tsv'\n",
    "\n",
    "# lineterminator='\\n': csv.writer defaults to '\\r\\n', which leaves a trailing '\\r'\n",
    "# on the last column for readers that only strip '\\n'.\n",
    "with open(output_path, 'w', newline='') as f_output:\n",
    "    csv_output = csv.writer(f_output, delimiter='\\t', lineterminator='\\n')\n",
    "    csv_output.writerows(tqdm(tsvs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the written TSV back (8 columns per row) as a sanity check.\n",
    "csv.field_size_limit(sys.maxsize)\n",
    "\n",
    "selected_col_ids = [0, 1, 2, 3, 4, 5, 6, 7]\n",
    "with open(output_path) as file:\n",
    "    out_data = [[line[i] for i in selected_col_ids]\n",
    "                for line in tqdm(csv.reader(file, delimiter='\\t'))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_data[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_data_path = '/data/mshukor/data/our_albef_data/data/vqa_train.json'\n",
    "with open(original_data_path) as f:\n",
    "    original_data = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_data[100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preprocess.utils import get_tsv_vqa_data_from_json\n",
    "\n",
    "# Example answer field: 1.0|!+horizontal&&0.3|!+south&&0.3|!+straight&&0.3|!+vertical\n",
    "start_id = 0\n",
    "dataset_name = 'vqav2'\n",
    "task_type = 'qa'\n",
    "image_root = '/data/mshukor/data/coco'\n",
    "\n",
    "vqa_rows = get_tsv_vqa_data_from_json(original_data, start_id, dataset_name, task_type, image_root=image_root, convert_images=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vqa_rows[10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Visual Grounding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_json(path):\n",
    "    \"\"\"Load a JSON file, closing the file handle afterwards.\"\"\"\n",
    "    with open(path) as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "original_data = load_json('/data/mshukor/data/our_albef_data/data/refcoco+_train.json')\n",
    "det_file = load_json('/data/mshukor/data/our_albef_data/data/refcoco+/dets.json')\n",
    "coco_file = load_json('/data/mshukor/data/our_albef_data/data/refcoco+/cocos.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(det_file.keys())[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "instances_path = '/data/mshukor/data/refcoco/refcoco+/instances.json'\n",
    "original_data = load_json(instances_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "ref_path = '/data/mshukor/data/refcoco/refcoco+/refs(unc).p'\n",
    "# NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
    "with open(ref_path, 'rb') as f:\n",
    "    refs = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: unfinished loop over refs. Its body was left empty (a SyntaxError), so it is\n",
    "# a no-op placeholder for now.\n",
    "for i, ref in tqdm(enumerate(refs)):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(refs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "refs[500]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_to_annot = {annot['id']: annot for annot in original_data['annotations']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_to_images = {image['id']: image for image in tqdm(original_data['images'])}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_to_images[576457]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Annotation boxes whose first value is > 0. Only a sample is shown, because printing\n",
    "# every one floods the output.\n",
    "positive_x_boxes = [r['bbox'] for r in tqdm(id_to_annot.values()) if r['bbox'][0] > 0]\n",
    "print(len(positive_x_boxes))\n",
    "positive_x_boxes[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(original_data.keys())[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): get_tsv_from_refcoco was never imported in this notebook (NameError).\n",
    "# It is assumed to live in preprocess.utils next to the other get_tsv_* helpers. TODO confirm.\n",
    "from preprocess.utils import get_tsv_from_refcoco\n",
    "\n",
    "ref_path = '/data/mshukor/data/refcoco/refcoco+/refs(unc).p'\n",
    "instances_path = '/data/mshukor/data/refcoco/refcoco+/instances.json'\n",
    "start_id = 0\n",
    "dataset_name = 'refcoco_train'\n",
    "task_type = 'visual_grounding'\n",
    "convert_images = False\n",
    "split = 'train'\n",
    "\n",
    "refcoco_rows = get_tsv_from_refcoco(ref_path, instances_path, start_id, dataset_name=dataset_name, task_type=task_type, convert_images=convert_images, split=split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "refcoco_rows[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Image.open('/data/mshukor/data/coco/train2014/COCO_train2014_000000000072.jpg').convert('RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_data['images'][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ['third book starting from left', '', '29.1,11.72,66.81,343.41', '', 'refcoco_train', 'visual_grounding']\n",
    "\n",
    "original_data['categories']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imagenet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: a .txt file with the image id and base64 image string of each image.\n",
    "# The final image TSV has the columns: id, image, code.\n",
    "from preprocess.utils import create_imagenet_txt_files\n",
    "\n",
    "path_data = '/data/mshukor/data/imagenet/val'\n",
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/imagenet_val.txt'\n",
    "\n",
    "create_imagenet_txt_files(path_data, output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(start_id, len(data))\n",
    "data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_image_only_tsv_from_code_files(code_path, output_path, start_id=0):\n",
    "    \"\"\"Write an image-only TSV from the first two columns of a code TSV.\n",
    "\n",
    "    Each output row is ``[id, col0, col1]``. Ids are consecutive, starting at\n",
    "    ``start_id``, so every row gets a unique id.\n",
    "    \"\"\"\n",
    "    # Read everything first so that output_path may safely equal code_path.\n",
    "    with open(code_path) as file:\n",
    "        rows = [[line[0], line[1]] for line in tqdm(csv.reader(file, delimiter='\\t'))]\n",
    "\n",
    "    # lineterminator='\\n': csv.writer defaults to '\\r\\n' line endings.\n",
    "    with open(output_path, 'w', newline='') as f_output:\n",
    "        csv_output = csv.writer(f_output, delimiter='\\t', lineterminator='\\n')\n",
    "        for row_id, row in enumerate(tqdm(rows), start=start_id):\n",
    "            csv_output.writerow([row_id] + row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_path = '/data/mshukor/data/ofa/pretrain_ours/imagenet_train_codes.tsv'\n",
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
    "\n",
    "save_image_only_tsv_from_code_files(code_path, output_path, start_id=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first written row (id, image, code).\n",
    "with open(output_path) as file:\n",
    "    first_row = next(csv.reader(file, delimiter='\\t'))\n",
    "out_data = [[first_row[i] for i in (0, 1, 2)]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of space-separated entries in the code column.\n",
    "len(out_data[0][2].split(' '))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fix image paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1281167it [00:16, 79250.80it/s]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "path_data =  '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
    "selected_cols='0,1,2'\n",
    "\n",
    "data = []\n",
    "\n",
    "selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
    "\n",
    "with open(path_data) as file:\n",
    "    tsv_file = csv.reader(file, delimiter='\\t')\n",
    "    for line in tqdm(tsv_file):\n",
    "\n",
    "        d = [line[i] for i in selected_col_ids]\n",
    "#         print(d)\n",
    "        data.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1281167it [00:16, 76760.12it/s]\n",
      "1281167it [00:01, 671149.72it/s]\n",
      "100%|█████| 1281167/1281167 [00:01<00:00, 947543.73it/s]\n"
     ]
    }
   ],
   "source": [
    "# Swap the image column of an image TSV for the path given by an (image id -> path) mapping file.\n",
    "def replace_image_id_by_path(input_tsv, output_tsv, mapping_file):\n",
    "    \"\"\"Rewrite column 1 of `input_tsv` using `mapping_file` and save the result to `output_tsv`.\n",
    "\n",
    "    Only the first three columns of each input row are kept. The lookup key is the file stem\n",
    "    of column 1 (text after the last '/' and before the first '.'); `mapping_file` is a TSV\n",
    "    whose column 0 is that key and column 1 the replacement path. The input is fully loaded\n",
    "    before writing, so `output_tsv` may be the same file as `input_tsv`.\n",
    "\n",
    "    Returns:\n",
    "        The rewritten rows as lists [col0, new_path, col2].\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a stem is missing from `mapping_file`.\n",
    "    \"\"\"\n",
    "    with open(input_tsv) as src:\n",
    "        rows = [[row[0], row[1], row[2]] for row in tqdm(csv.reader(src, delimiter='\\t'))]\n",
    "\n",
    "    with open(mapping_file) as src:\n",
    "        path_by_id = {row[0]: row[1] for row in tqdm(csv.reader(src, delimiter='\\t'))}\n",
    "\n",
    "    for row in tqdm(rows):\n",
    "        stem = row[1].split('/')[-1].split('.')[0]\n",
    "        row[1] = path_by_id[stem]\n",
    "\n",
    "    with open(output_tsv, 'w', newline='') as dst:\n",
    "        csv.writer(dst, delimiter='\\t').writerows(tqdm(rows))\n",
    "\n",
    "    return rows\n",
    "\n",
    "input_tsv = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
    "output_tsv = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
    "mapping_file = '/data/mshukor/data/ofa/pretrain_ours/imagenet_train.txt'\n",
    "\n",
    "tmp = replace_image_id_by_path(input_tsv, output_tsv, mapping_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['0',\n",
       " 'RawImages/train/n03146219/n03146219_8050.JPEG',\n",
       " '7442 662 7977 1652 6320 650 4376 992 1596 7734 1925 5335 3935 5604 5697 4504 5114 4050 144 215 144 6691 5321 7769 4755 3346 4691 3469 4175 1351 6907 9 6948 7749 7166 215 1026 931 970 4168 2675 6874 6248 2306 6138 8052 2970 6302 5550 2491 6931 969 6574 8014 6588 6639 389 1882 688 4691 4266 675 6248 3938 2387 4365 5999 261 2966 3499 651 5290 970 3526 5583 516 167 2103 1513 198 6657 7442 1118 7207 7307 1792 2078 388 4285 3417 5450 6959 6999 1306 1649 4556 2533 1103 6869 7681 8051 1916 7160 7743 2704 8063 2726 4860 2383 1635 8061 3497 7327 5915 7836 5697 1719 2136 96 970 7184 5167 2250 404 7007 7565 2742 33 7076 5250 7790 1838 1298 2847 3250 1204 1934 5550 4360 5688 1791 3465 634 4663 2991 5352 4066 4157 946 1596 3504 5855 5629 5411 7695 3627 3942 5631 2736 2883 5059 1423 2009 2643 1873 4960 1661 545 1396 3450 3145 211 6869 2226 6780 2724 4606 3702 3667 891 6236 6419 3531 7032 5277 3381 3031 7878 725 1652 1813 5037 949 3087 405 7884 3784 5432 633 4256 235 3182 3686 5450 2419 1593 7948 5741 6237 7233 20 7470 7071 182 1584 6780 7913 2691 7207 5094 5199 4502 5030 2360 448 5129 2713 1094 1678 1934 2458 2970 2133 867 3332 6138 294 3260 5495 4189 5732 3940 5629 4139 7335 7607 3248 4981 2109 3660 4364 7763 3964 7163 6702 691']"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tmp[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████| 1281167/1281167 [00:03<00:00, 336250.44it/s]\n"
     ]
    }
   ],
   "source": [
    "# imgage_dir = 'imagenet/RawImages/train/'\n",
    "# for d in tqdm(data):\n",
    "#     im_id = d[1]\n",
    "#     im_dir = im_id.split('_')[0]\n",
    "#     im_path = os.path.join(im_dir, im_id+'.JPEG')\n",
    "#     d[1] = os.path.join(imgage_dir, im_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['0',\n",
       " 'imagenet/RawImages/train/n03146219/n03146219_8050.JPEG',\n",
       " '7442 662 7977 1652 6320 650 4376 992 1596 7734 1925 5335 3935 5604 5697 4504 5114 4050 144 215 144 6691 5321 7769 4755 3346 4691 3469 4175 1351 6907 9 6948 7749 7166 215 1026 931 970 4168 2675 6874 6248 2306 6138 8052 2970 6302 5550 2491 6931 969 6574 8014 6588 6639 389 1882 688 4691 4266 675 6248 3938 2387 4365 5999 261 2966 3499 651 5290 970 3526 5583 516 167 2103 1513 198 6657 7442 1118 7207 7307 1792 2078 388 4285 3417 5450 6959 6999 1306 1649 4556 2533 1103 6869 7681 8051 1916 7160 7743 2704 8063 2726 4860 2383 1635 8061 3497 7327 5915 7836 5697 1719 2136 96 970 7184 5167 2250 404 7007 7565 2742 33 7076 5250 7790 1838 1298 2847 3250 1204 1934 5550 4360 5688 1791 3465 634 4663 2991 5352 4066 4157 946 1596 3504 5855 5629 5411 7695 3627 3942 5631 2736 2883 5059 1423 2009 2643 1873 4960 1661 545 1396 3450 3145 211 6869 2226 6780 2724 4606 3702 3667 891 6236 6419 3531 7032 5277 3381 3031 7878 725 1652 1813 5037 949 3087 405 7884 3784 5432 633 4256 235 3182 3686 5450 2419 1593 7948 5741 6237 7233 20 7470 7071 182 1584 6780 7913 2691 7207 5094 5199 4502 5030 2360 448 5129 2713 1094 1678 1934 2458 2970 2133 867 3332 6138 294 3260 5495 4189 5732 3940 5629 4139 7335 7607 3248 4981 2109 3660 4364 7763 3964 7163 6702 691']"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████| 1281167/1281167 [00:27<00:00, 46704.02it/s]\n"
     ]
    }
   ],
   "source": [
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
    "with open(output_path, 'w', newline='') as f_output:\n",
    "    csv_output = csv.writer(f_output, delimiter='\\t')\n",
    "\n",
    "    for t in tqdm(data):\n",
    "        csv_output.writerow(t)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Object detection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### COCO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# '505.856,189.994,799.744,450.016,/m/07j7r,tree&&753.664,384.00600000000003,827.392,446.572,/m/0c9ph5,flower'\n",
    "\n",
    "path_json = '/data/mshukor/data/coco/annotations/instances_train2014.json'\n",
    "\n",
    "data = json.load(open(path_json,'r'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tsv_from_coco_detection(instances_path, start_id, convert_images=True, split='train'):\n",
    "    \"\"\"Build detection rows [id, image, detections] from a COCO instances annotation file.\n",
    "\n",
    "    Args:\n",
    "        instances_path: COCO `instances_<split>2014.json` file.\n",
    "        start_id: id of the first emitted row; incremented by one per row.\n",
    "        convert_images: if True, the image column is `convert_img_to_str(<absolute path>)`;\n",
    "            otherwise it is the path relative to '/data/mshukor/data/'.\n",
    "        split: split name; images are expected under 'coco/<split>2014/'.\n",
    "\n",
    "    Returns:\n",
    "        A list of [id, image, detections] rows, where detections joins with '&&' entries\n",
    "        'x1,y1,x2,y2,category_id,category_name' (corners with 3 decimals), largest box first.\n",
    "        Images without any annotation are skipped and counted in the printed summary.\n",
    "    \"\"\"\n",
    "    data_root = '/data/mshukor/data/'\n",
    "\n",
    "    # Close the annotation file instead of leaking the handle.\n",
    "    with open(instances_path, 'r') as f:\n",
    "        instances = json.load(f)\n",
    "\n",
    "    imgid_to_annot = {}\n",
    "    for annot in tqdm(instances['annotations']):\n",
    "        imgid_to_annot.setdefault(annot['image_id'], []).append(annot)\n",
    "\n",
    "    id_to_category = {cat['id']: cat['name'] for cat in instances['categories']}\n",
    "\n",
    "    tsv_data = []\n",
    "    missed = []\n",
    "    for ref in tqdm(instances['images']):\n",
    "        image_id = ref['id']\n",
    "        # Check annotations first so unannotated images are not converted for nothing.\n",
    "        if image_id not in imgid_to_annot:\n",
    "            missed.append(image_id)\n",
    "            continue\n",
    "\n",
    "        # Every split gets its directory prefix (previously only 'train' did).\n",
    "        file_name = os.path.join('coco', '{}2014'.format(split), ref['file_name'])\n",
    "        if convert_images:\n",
    "            img = convert_img_to_str(os.path.join(data_root, file_name))\n",
    "        else:\n",
    "            img = file_name\n",
    "\n",
    "        detections = []\n",
    "        areas = []\n",
    "        for annot in imgid_to_annot[image_id]:\n",
    "            # COCO boxes are [x, y, width, height] with (x, y) the top-left corner.\n",
    "            x, y, w, h = annot['bbox']\n",
    "            x1, y1, x2, y2 = x, y, x + w, y + h\n",
    "            object_id = annot['category_id']\n",
    "            category = id_to_category[object_id]\n",
    "            areas.append(w * h)\n",
    "            detections.append('{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category))\n",
    "\n",
    "        # Largest boxes first (sorted() stays stable for ties even with reverse=True).\n",
    "        sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
    "        detections = '&&'.join(detections[k] for k in sorted_indices)\n",
    "        tsv_data.append([start_id, img, detections])\n",
    "        start_id += 1\n",
    "\n",
    "    print('missed images:', len(missed))\n",
    "    return tsv_data\n",
    "\n",
    "instances_path = '/data/mshukor/data/coco/annotations/instances_train2014.json'\n",
    "start_id = 0\n",
    "tmp = get_tsv_from_coco_detection(instances_path, start_id, convert_images=False, split='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(imgid_to_annot.keys())[:10]\n",
    "len(missied)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tsv_from_vg_detection(instances_path, path_images, start_id, convert_images=True, split='train'):\n",
    "    \"\"\"Build detection rows [id, image, detections] from Visual Genome `objects.json`.\n",
    "\n",
    "    Args:\n",
    "        instances_path: VG objects file; a list of entries with an 'id' (image id) and an\n",
    "            'objects' list whose items have 'x', 'y', 'w', 'h', 'id' and 'names'.\n",
    "        path_images: directory whose subdirectories (at any depth) hold the images; files\n",
    "            directly inside `path_images` are ignored. Image ids are file names up to the first '.'.\n",
    "        start_id: id of the first emitted row; incremented by one per row.\n",
    "        convert_images: if True, the image column is `convert_img_to_str(<absolute path>)`;\n",
    "            otherwise it is the last 4 components of the image path.\n",
    "        split: unused, kept for a signature parallel to the other builders.\n",
    "\n",
    "    Returns:\n",
    "        A list of [id, image, detections] rows, where detections joins with '&&' entries\n",
    "        'x1,y1,x2,y2,object_id,names' (corners clipped at 0), largest box first.\n",
    "        Images without objects, or whose file name is not a numeric VG id, are skipped.\n",
    "    \"\"\"\n",
    "    data_root = '/data/mshukor/data/'\n",
    "\n",
    "    with open(instances_path, 'r') as f:\n",
    "        id_to_objects = {d['id']: d for d in json.load(f)}\n",
    "\n",
    "    # One recursive walk; like the former nested walk, files directly in path_images are skipped.\n",
    "    id_to_image_path = {}\n",
    "    for root, _, files in os.walk(path_images):\n",
    "        if root == path_images:\n",
    "            continue\n",
    "        for f in files:\n",
    "            file_path = '/'.join(os.path.join(root, f).split('/')[-4:])\n",
    "            id_to_image_path[f.split('.')[0]] = file_path\n",
    "\n",
    "    tsv_data = []\n",
    "    missed = []\n",
    "    negs = []\n",
    "    for image_id, file_name in tqdm(id_to_image_path.items()):\n",
    "        # Non-numeric stems (stray files) used to crash int(); skip them instead.\n",
    "        if not image_id.isdecimal() or int(image_id) not in id_to_objects:\n",
    "            missed.append(image_id)\n",
    "            continue\n",
    "        objects = id_to_objects[int(image_id)]['objects']\n",
    "        if len(objects) == 0:\n",
    "            missed.append(image_id)\n",
    "            continue\n",
    "\n",
    "        if convert_images:\n",
    "            img = convert_img_to_str(os.path.join(data_root, file_name))\n",
    "        else:\n",
    "            img = file_name.replace(data_root, '')\n",
    "\n",
    "        areas = []\n",
    "        detections = []\n",
    "        for annot in objects:\n",
    "            # VG boxes are an (x, y) top-left corner plus width/height.\n",
    "            x, y, w, h = annot['x'], annot['y'], annot['w'], annot['h']\n",
    "            x1, y1, x2, y2 = x, y, x + w, y + h\n",
    "            if min(x1, y1, x2, y2) < 0:\n",
    "                negs.append(annot)\n",
    "            # Clip both axes at 0 (previously only x was clipped, leaving negative y).\n",
    "            x1, y1, x2, y2 = max(0, x1), max(0, y1), max(0, x2), max(0, y2)\n",
    "\n",
    "            category = ','.join(annot['names']).replace('\\x00', '')\n",
    "            object_id = annot['id']\n",
    "            detections.append('{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category))\n",
    "            areas.append(w * h)\n",
    "\n",
    "        sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
    "        detections = '&&'.join(detections[k] for k in sorted_indices)\n",
    "        tsv_data.append([start_id, img, detections])\n",
    "        start_id += 1\n",
    "\n",
    "    print('missed images:', len(missed), 'negs', len(negs))\n",
    "    return tsv_data\n",
    "\n",
    "\n",
    "instances_path = '/data/mshukor/data/visual_genome/annotations/objects.json'\n",
    "path_images = '/data/mshukor/data/visual_genome/images'\n",
    "start_id = 0\n",
    "\n",
    "tmp = get_tsv_from_vg_detection(instances_path, path_images, start_id, convert_images=False, split='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_root = '/data/mshukor/data/'\n",
    "\n",
    "Image.open(image_root+id_to_image_path['1087']).convert('RGB')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### OpenImagesV5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_path = '/data/mshukor/data/OpenImagesV5/train-annotations-bbox.csv'\n",
    "# data_path = '/data/mshukor/data/OpenImagesV5/train-images-boxable.csv'\n",
    "# data_path = '/data/mshukor/data/OpenImagesV5/train-images-boxable-with-rotation.csv'\n",
    "data_path = '/data/mshukor/data/OpenImagesV5/class-descriptions-boxable.csv'\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "selected_col_ids = [0,1,2]\n",
    "out_data = []\n",
    "with open(data_path) as file:\n",
    "    tsv_file = csv.reader(file, delimiter='\\t')\n",
    "    for i, line in tqdm(enumerate(tsv_file)):\n",
    "        # d = [line[i] for i in selected_col_ids]\n",
    "        out_data.append(line)\n",
    "#         print(line)\n",
    "#         if i > 2:\n",
    "#             break\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tsv_from_openimages_detection(instances_path, path_images, start_id, convert_images=False, split='train',\n",
    "                                      class_path=None, image_root='/gpfsdswork/dataset'):\n",
    "    \"\"\"Build detection rows [id, image, detections] from Open Images box annotations.\n",
    "\n",
    "    Args:\n",
    "        instances_path: Open Images box CSV (e.g. train-annotations-bbox.csv). Columns used:\n",
    "            0 ImageID, 2 LabelName, 4-7 XMin, XMax, YMin, YMax (relative coordinates in [0, 1]).\n",
    "        path_images: directory whose subdirectories (at any depth) hold the images; image ids\n",
    "            are file names up to the first '.'.\n",
    "        start_id: id of the first emitted row; incremented by one per row.\n",
    "        convert_images: if True, the image column is `convert_img_to_str(<absolute path>)`;\n",
    "            otherwise it is the last 4 components of the image path.\n",
    "        split: unused, kept for a signature parallel to the other builders.\n",
    "        class_path: CSV mapping LabelName -> class name (class-descriptions-boxable.csv);\n",
    "            defaults to that file next to `instances_path`.\n",
    "        image_root: directory the relative image paths are resolved against to read image sizes.\n",
    "\n",
    "    Returns:\n",
    "        A list of [id, image, detections] rows, where detections joins with '&&' entries\n",
    "        'x1,y1,x2,y2,label_id,class_name' in pixels, largest box first. Annotated images\n",
    "        not found under `path_images` are skipped and counted in the printed summary.\n",
    "    \"\"\"\n",
    "    if class_path is None:\n",
    "        class_path = os.path.join(os.path.dirname(instances_path), 'class-descriptions-boxable.csv')\n",
    "\n",
    "    id_to_image_path = {}\n",
    "    for root, _, files in os.walk(path_images):\n",
    "        if root == path_images:\n",
    "            continue\n",
    "        for f in files:\n",
    "            file_path = '/'.join(os.path.join(root, f).split('/')[-4:])\n",
    "            id_to_image_path[f.split('.')[0]] = file_path\n",
    "\n",
    "    def imagepath_to_image_size(path):\n",
    "        # PIL's Image.size is (width, height); the file is closed right after reading it.\n",
    "        with Image.open(path) as im:\n",
    "            return im.size\n",
    "\n",
    "    # Parse as a real comma-separated CSV instead of a tab reader plus a manual split.\n",
    "    id_to_annot = {}\n",
    "    with open(instances_path, newline='') as file:\n",
    "        for line in tqdm(csv.reader(file)):\n",
    "            if not line or line[0] == 'ImageID':  # blank line or header row\n",
    "                continue\n",
    "            id_to_annot.setdefault(line[0], []).append(line)\n",
    "\n",
    "    with open(class_path, newline='') as file:\n",
    "        classid_to_class = {line[0]: line[1] for line in csv.reader(file) if line}\n",
    "\n",
    "    tsv_data = []\n",
    "    missed = []\n",
    "    for img_id, annots in tqdm(id_to_annot.items()):\n",
    "        if img_id not in id_to_image_path:\n",
    "            missed.append(img_id)\n",
    "            continue\n",
    "        img_path = id_to_image_path[img_id]\n",
    "        orig_img_path = os.path.join(image_root, img_path)\n",
    "        w, h = imagepath_to_image_size(orig_img_path)\n",
    "        img = convert_img_to_str(orig_img_path) if convert_images else img_path\n",
    "\n",
    "        areas = []\n",
    "        detections = []\n",
    "        for d in annots:\n",
    "            # CSV fields are strings; convert before scaling the relative coordinates to pixels.\n",
    "            x1, x2, y1, y2 = (float(v) for v in d[4:8])\n",
    "            x1, x2, y1, y2 = x1 * w, x2 * w, y1 * h, y2 * h\n",
    "            areas.append((x2 - x1) * (y2 - y1))\n",
    "\n",
    "            object_id = d[2]\n",
    "            category = classid_to_class[object_id]\n",
    "            detections.append('{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category))\n",
    "\n",
    "        sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
    "        detections = '&&'.join(detections[k] for k in sorted_indices)\n",
    "        tsv_data.append([start_id, img, detections])\n",
    "        start_id += 1\n",
    "\n",
    "    print('missed images:', len(missed))\n",
    "    return tsv_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "e39871fd9fd74f55"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### En Wikipedia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %env stores the value verbatim, quotes included, so values must not be quoted.\n",
    "# NOTE: `datasets` reads HF_DATASETS_* when it is first imported; run this cell before `from datasets import load_dataset`.\n",
    "%env http_proxy=http://192.168.0.100:3128\n",
    "%env https_proxy=http://192.168.0.100:3128\n",
    "\n",
    "%env HF_DATASETS_CACHE=/data/mshukor/data/.cache\n",
    "%env HF_DATASETS_OFFLINE=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp = load_dataset(\"wikipedia\", \"20220301.en\", cache_dir=\"/data/mshukor/data/.cache\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(tmp['train'][0]['text'])\n",
    "tmp['train'][0]['text'][:512]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_special(input_string):\n",
    "    \"\"\"Return `input_string` with every character removed except spaces and `str.isalnum()` ones.\n",
    "\n",
    "    Tabs, newlines and punctuation are dropped (not replaced by spaces); non-ASCII letters\n",
    "    and digits are kept because `str.isalnum` is Unicode-aware.\n",
    "    \"\"\"\n",
    "    kept = (ch for ch in input_string if ch == ' ' or ch.isalnum())\n",
    "    return ''.join(kept)\n",
    "\n",
    "def get_tsv_from_text_data(data_name=\"wikipedia\", data_subname=\"20220301.en\", \n",
    "                           output_path=None, cache_dir=\"/data/mshukor/data/.cache\", start_id=0, num_max_characters=2500):\n",
    "    \"\"\"Write a text-only pretraining TSV (`id<TAB>text`) from a Hugging Face dataset.\n",
    "\n",
    "    Each example of the 'train' split is truncated to `num_max_characters`, tabs and newlines\n",
    "    become spaces, double quotes are dropped, and `remove_special` then keeps only spaces and\n",
    "    alphanumeric characters. Ids start at `start_id` and grow by one per example.\n",
    "\n",
    "    Args:\n",
    "        data_name, data_subname: dataset name and config passed to `datasets.load_dataset`.\n",
    "        output_path: destination TSV (required; overwritten).\n",
    "        cache_dir: Hugging Face cache directory.\n",
    "        start_id: id of the first row.\n",
    "        num_max_characters: characters kept per example, counted before cleaning.\n",
    "\n",
    "    Returns:\n",
    "        The id the next row would receive (start_id + number of rows written).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `output_path` is not given.\n",
    "    \"\"\"\n",
    "    # A parameter without default after defaulted ones is a SyntaxError, so `output_path`\n",
    "    # keeps its position but defaults to None and is validated here.\n",
    "    if output_path is None:\n",
    "        raise ValueError('output_path is required')\n",
    "    from datasets import load_dataset\n",
    "    dataset = load_dataset(data_name, data_subname, cache_dir=cache_dir)\n",
    "\n",
    "    with open(output_path, 'w', newline='') as f_output:\n",
    "        csv_output = csv.writer(f_output, delimiter='\\t')\n",
    "        for t in tqdm(dataset['train']):\n",
    "            text = t['text'][:num_max_characters].replace('\\t', ' ').replace('\\n', ' ').replace('\"', '')\n",
    "            csv_output.writerow([start_id, remove_special(text)])\n",
    "            start_id += 1\n",
    "    return start_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "from io import StringIO\n",
    "\n",
    "output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
    "\n",
    "start_id = 0 \n",
    "num_max_characters = 2500\n",
    "\n",
    "with open(output_path, 'w', newline='') as f_output:\n",
    "    csv_output = csv.writer(f_output, delimiter='\\t')\n",
    "\n",
    "    for i, t in tqdm(enumerate(tmp['train'])):\n",
    "        text = t['text'][:num_max_characters]\n",
    "        item = [start_id, text]\n",
    "        csv_output.writerow(item)\n",
    "        start_id+=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_data = []\n",
    "selected_cols='0,1'\n",
    "\n",
    "selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
    "\n",
    "with open(output_path) as file:\n",
    "    tsv_file = csv.reader(file, delimiter='\\t')\n",
    "    for line in tqdm(tsv_file):\n",
    "        d = [line[i] for i in selected_col_ids]\n",
    "        out_data.append(d)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create from finetuned data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: read from tsv and write to tsv directly\n",
    "# TODO: same for vqa v2\n",
    "# TODO: then create ofa_mini 4m, vqa and refcoco for pretraining"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "from models import ofa_base_architecture, OFAModel\n",
    "from transformers.tokenization_utils_base import BatchEncoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explore ofa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-11-15 08:52:08 | INFO | tasks.ofa_task | source dictionary: 59457 types\n",
      "2022-11-15 08:52:08 | INFO | tasks.ofa_task | target dictionary: 59457 types\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from fairseq import utils, tasks\n",
    "from fairseq import checkpoint_utils\n",
    "from utils.eval_utils import eval_step\n",
    "from tasks.mm_tasks.caption import CaptionTask\n",
    "from models.ofa import OFAModel\n",
    "from PIL import Image\n",
    "\n",
    "# Register refcoco task\n",
    "tasks.register_task('caption', CaptionTask)\n",
    "\n",
    "# turn on cuda if GPU is available\n",
    "use_cuda = torch.cuda.is_available()\n",
    "# use fp16 only when GPU is available\n",
    "use_fp16 = False\n",
    "\n",
    "# Load pretrained ckpt & config\n",
    "overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":16, \"no_repeat_ngram_size\":3, \"seed\":7}\n",
    "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n",
    "        utils.split_paths('/data/mshukor/logs/ofa/checkpoints/caption/ofa_caption_stage_1/5_0.06_6000/checkpoint_best.pt'),\n",
    "        arg_overrides=overrides\n",
    "    )\n",
    "\n",
    "# Move models to GPU\n",
    "for model in models:\n",
    "    model.eval()\n",
    "    if use_fp16:\n",
    "        model.half()\n",
    "    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n",
    "        model.cuda()\n",
    "    model.prepare_for_inference_(cfg)\n",
    "\n",
    "# Initialize generator\n",
    "generator = task.build_generator(models, cfg.generation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ofa = models[0]\n",
    "ofa_state = model_ofa.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_state_given_key(state, key, excluded_keys=None):\n",
    "    \"\"\"Return the entries of `state` whose name contains `key` and none of `excluded_keys`.\n",
    "\n",
    "    Args:\n",
    "        state: mapping from parameter name to value, e.g. a model `state_dict()`.\n",
    "        key: substring a name must contain to be selected.\n",
    "        excluded_keys: optional list of substrings; a name containing any of them is dropped.\n",
    "\n",
    "    Returns:\n",
    "        A new dict with the selected (name, value) pairs, in the iteration order of `state`.\n",
    "    \"\"\"\n",
    "    def _is_selected(name):\n",
    "        if key not in name:\n",
    "            return False\n",
    "        if excluded_keys is None:\n",
    "            return True\n",
    "        return not any([ek in name for ek in excluded_keys])\n",
    "\n",
    "    return {name: value for name, value in state.items() if _is_selected(name)}\n",
    "\n",
    "key = 'encoder.layers.0'\n",
    "excluded_keys = ['embed', 'image']\n",
    "ofa_tmp = get_state_given_key(ofa_state, key, excluded_keys=excluded_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_ofa_args_large(args):\n",
    "#     args['encoder_embed_path'] = getattr(args, \"encoder_embed_path\", None)\n",
    "#     args['encoder_embed_dim'] = getattr(args, \"encoder_embed_dim\", 1024)\n",
    "#     args['encoder_ffn_embed_dim'] = getattr(args, \"encoder_ffn_embed_dim\", 4 * 1024)\n",
    "#     args['encoder_layers'] = getattr(args, \"encoder_layers\", 12)\n",
    "#     args['encoder_attention_heads'] = getattr(args, \"encoder_attention_heads\", 16)\n",
    "#     args['encoder_normalize_before'] = getattr(args, \"encoder_normalize_before\", True)\n",
    "#     args['encoder_learned_pos'] = getattr(args, \"encoder_learned_pos\", True)\n",
    "#     args['decoder_embed_path'] = getattr(args, \"decoder_embed_path\", None)\n",
    "#     args['decoder_embed_dim'] = getattr(args, \"decoder_embed_dim\", args['encoder_embed_dim'])\n",
    "#     args['decoder_ffn_embed_dim'] = getattr(\n",
    "#         args, \"decoder_ffn_embed_dim\", args['encoder_ffn_embed_dim']\n",
    "#     )\n",
    "#     args['decoder_layers'] = getattr(args, \"decoder_layers\", 12)\n",
    "#     args['decoder_attention_heads'] = getattr(args, \"decoder_attention_heads\", 16)\n",
    "#     args['decoder_normalize_before'] = getattr(args, \"decoder_normalize_before\", True)\n",
    "#     args['decoder_learned_pos'] = getattr(args, \"decoder_learned_pos\", True)\n",
    "#     args['attention_dropout'] = getattr(args, \"attention_dropout\", 0.0)\n",
    "#     args['relu_dropout'] = getattr(args, \"relu_dropout\", 0.0)\n",
    "#     args['dropout'] = getattr(args, \"dropout\", 0.0)\n",
    "#     args['max_target_positions'] = getattr(args, \"max_target_positions\", 1024)\n",
    "#     args['max_source_positions'] = getattr(args, \"max_source_positions\", 1024)\n",
    "#     args['adaptive_softmax_cutoff'] = getattr(args, \"adaptive_softmax_cutoff\", None)\n",
    "#     args['adaptive_softmax_dropout'] = getattr(args, \"adaptive_softmax_dropout\", 0)\n",
    "#     args['share_decoder_input_output_embed'] = getattr(\n",
    "#         args, \"share_decoder_input_output_embed\", True\n",
    "#     )\n",
    "#     args['share_all_embeddings'] = getattr(args, \"share_all_embeddings\", True)\n",
    "\n",
    "#     args['decoder_output_dim'] = getattr(\n",
    "#         args, \"decoder_output_dim\", args['decoder_embed_dim']\n",
    "#     )\n",
    "#     args['decoder_input_dim'] = getattr(args, \"decoder_input_dim\", args['decoder_embed_dim'])\n",
    "\n",
    "#     args['no_scale_embedding'] = getattr(args, \"no_scale_embedding\", True)\n",
    "#     args['layernorm_embedding'] = getattr(args, \"layernorm_embedding\", True)\n",
    "\n",
    "#     args['activation_fn'] = getattr(args, \"activation_fn\", \"gelu\")\n",
    "#     args['pooler_activation_fn'] = getattr(args, \"pooler_activation_fn\", \"tanh\")\n",
    "#     args['pooler_dropout'] = getattr(args, \"pooler_dropout\", 0.0)\n",
    "#     args['pooler_classifier'] = getattr(args, \"pooler_classifier\", \"mlp\")\n",
    "\n",
    "#     args['resnet_drop_path_rate'] = getattr(args, \"resnet_drop_path_rate\", 0.0)\n",
    "#     args['encoder_drop_path_rate'] = getattr(args, \"encoder_drop_path_rate\", 0.0)\n",
    "#     args['decoder_drop_path_rate'] = getattr(args, \"decoder_drop_path_rate\", 0.0)\n",
    "\n",
    "#     args['resnet_type'] = getattr(args, \"resnet_type\", \"resnet152\")\n",
    "#     args['token_bucket_size'] = getattr(args, \"token_bucket_size\", 256)\n",
    "#     args['image_bucket_size'] = getattr(args, \"image_bucket_size\", 42)\n",
    "\n",
    "#     args['freeze_encoder_embedding'] = getattr(args, \"freeze_encoder_embedding\", False)\n",
    "#     args['freeze_decoder_embedding'] = getattr(args, \"freeze_decoder_embedding\", False)\n",
    "#     args['add_type_embedding'] = getattr(args, \"add_type_embedding\", True)\n",
    "#     args['attn_scale_factor'] = getattr(args, \"attn_scale_factor\", 2)\n",
    "\n",
    "#     args['code_image_size'] = getattr(args, \"code_image_size\", 128)\n",
    "#     args['patch_layernorm_embedding'] = getattr(args, \"patch_layernorm_embedding\", True)\n",
    "#     args['code_layernorm_embedding'] = getattr(args, \"code_layernorm_embedding\", True)\n",
    "#     args['entangle_position_embedding'] = getattr(args, \"entangle_position_embedding\", False)\n",
    "#     args['disable_entangle'] = getattr(args, \"disable_entangle\", False)\n",
    "#     args['sync_bn'] = getattr(args, \"sync_bn\", False)\n",
    "\n",
    "#     args['scale_attn'] = getattr(args, \"scale_attn\", False)\n",
    "#     args['scale_fc'] = getattr(args, \"scale_fc\", False)\n",
    "#     args['scale_heads'] = getattr(args, \"scale_heads\", False)\n",
    "#     args['scale_resids'] = getattr(args, \"scale_resids\", False)\n",
    "\n",
    "#     args['orig_patch_image_size'] = getattr(args, \"orig_patch_image_size\", 256)\n",
    "\n",
    "#     return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# args = {}\n",
    "# args = get_ofa_args_large(args)\n",
    "# args = BatchEncoding(args)\n",
    "# ofa_base_architecture(args)\n",
    "# data_dir = '/data/mshukor/data/ofa/pretrain_example'\n",
    "\n",
    "# cfg.task.neg_sample_dir = data_dir+'/negative_sample'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert T5 weights to OFA parameter names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face T5 checkpoint whose parameters are renamed to OFA names below.\n",
    "model_t5 = T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=\"t5-base\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the module hierarchy of the T5 model.\n",
    "model_t5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat mapping of parameter name -> tensor that is fed to the renaming step.\n",
    "t5_state = model_t5.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Regex rewrites applied, in insertion order, to every T5 key to obtain the OFA (fairseq) key.\n",
    "# Dots are escaped on purpose: an unescaped '.v.' would also match 'ive' inside\n",
    "# 'relative_attention_bias' and corrupt that key.\n",
    "mapping_dict = {\n",
    "    # encoder and decoder blocks\n",
    "    r'block': 'layers',\n",
    "    r'layer\\.[0-9]+\\.SelfAttention': 'self_attn',\n",
    "    r'\\.q\\.': '.q_proj.',\n",
    "    r'\\.k\\.weight': '.k_proj.weight',\n",
    "    r'\\.v\\.': '.v_proj.',\n",
    "    r'layer\\.0\\.layer_norm\\.': 'self_attn_layer_norm.',\n",
    "    r'layer\\.[0-9]+\\.DenseReluDense\\.': '',\n",
    "    r'\\.wi\\.': '.fc1.',\n",
    "    r'\\.wo\\.': '.fc2.',\n",
    "    # decoder cross-attention\n",
    "    r'layer\\.[0-9]+\\.EncDecAttention': 'encoder_attn',\n",
    "}\n",
    "\n",
    "# Applied only to encoder keys: the norm of encoder sub-layer 1 becomes fairseq's 'final_layer_norm'.\n",
    "encoder_mapping = {\n",
    "    r'layer\\.1\\.layer_norm\\.': 'final_layer_norm.',\n",
    "}\n",
    "\n",
    "# Applied only to decoder keys: sub-layer 1 norm -> cross-attention norm, sub-layer 2 norm -> 'final_layer_norm'.\n",
    "decoder_mapping = {\n",
    "    r'layer\\.1\\.layer_norm\\.': 'encoder_attn_layer_norm.',\n",
    "    r'layer\\.2\\.layer_norm\\.': 'final_layer_norm.',\n",
    "}\n",
    "\n",
    "# Literal (non-regex) substring replacements, applied before the regex rewrites.\n",
    "simple_replace_mapping = {\n",
    "    '.o.weight': '.out_proj.weight',\n",
    "}\n",
    "\n",
    "\n",
    "def modify_state(state, mapping_dict, encoder_mapping, decoder_mapping, simple_replace_mapping):\n",
    "    \"\"\"Rename the parameters of a Hugging Face T5 state dict to OFA (fairseq) names.\n",
    "\n",
    "    Args:\n",
    "        state: mapping from T5 parameter name to tensor, e.g. ``model_t5.state_dict()``.\n",
    "        mapping_dict: regex -> replacement, applied to every key.\n",
    "        encoder_mapping: regex -> replacement, applied only to keys with an 'encoder' component.\n",
    "        decoder_mapping: regex -> replacement, applied only to keys with a 'decoder' component.\n",
    "        simple_replace_mapping: literal substring -> replacement, applied first.\n",
    "\n",
    "    Returns:\n",
    "        ``{'model': renamed_state}`` (fairseq checkpoint layout). Keys that no rule\n",
    "        changes (e.g. ``shared.weight``, ``encoder.embed_tokens.weight``) are kept as-is.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if two source keys are renamed to the same target key.\n",
    "    \"\"\"\n",
    "    # Build a fresh dict: deleting the originals from a copy would also drop\n",
    "    # every key whose name is unchanged by the rules.\n",
    "    renamed = {}\n",
    "    source_of = {}  # target key -> source key, used to report collisions\n",
    "    for old_key, value in state.items():\n",
    "        new_key = old_key\n",
    "        for old, new in simple_replace_mapping.items():\n",
    "            new_key = new_key.replace(old, new)\n",
    "        for pattern, repl in mapping_dict.items():\n",
    "            new_key = re.sub(pattern, repl, new_key)\n",
    "\n",
    "        # Compare whole dotted components: decoder keys contain 'encoder_attn',\n",
    "        # which a plain substring test for 'encoder' would also match.\n",
    "        parts = new_key.split('.')\n",
    "        if 'encoder' in parts:\n",
    "            for pattern, repl in encoder_mapping.items():\n",
    "                new_key = re.sub(pattern, repl, new_key)\n",
    "        if 'decoder' in parts:\n",
    "            for pattern, repl in decoder_mapping.items():\n",
    "                new_key = re.sub(pattern, repl, new_key)\n",
    "\n",
    "        if new_key in renamed:\n",
    "            raise ValueError(\n",
    "                f'{source_of[new_key]!r} and {old_key!r} both map to {new_key!r}')\n",
    "        renamed[new_key] = value\n",
    "        source_of[new_key] = old_key\n",
    "\n",
    "    return {'model': renamed}\n",
    "\n",
    "\n",
    "new_state = modify_state(t5_state, mapping_dict, encoder_mapping, decoder_mapping, simple_replace_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_state['model'].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_states(state1, state2):\n",
    "    \"\"\"Return the keys of ``state1`` that are missing from ``state2``, in ``state1`` order.\"\"\"\n",
    "    return [key for key in state1 if key not in state2]\n",
    "\n",
    "\n",
    "# Compare the renamed parameter dict itself (not its {'model': ...} wrapper) with the OFA state dict.\n",
    "missing_in_ofa = compare_states(new_state['model'], ofa_state)\n",
    "missing_in_ofa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path = '/data/mshukor/logs/ofa/pretrained_models/t5_base.pt'\n",
    "torch.save(new_state, output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path = '/data/mshukor/logs/ofa/pretrained_models/t5_base.pt'\n",
    "\n",
    "# Load onto the CPU (as done for the BART checkpoint) so the device used when saving is not required.\n",
    "t5_converted_state = torch.load(output_path, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# strict=False tolerates name mismatches; the returned value lists missing and unexpected keys.\n",
    "model_ofa.load_state_dict(t5_converted_state['model'], strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t5_converted_state['model'].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference OFA checkpoint, kept under its own name so it does not overwrite the converted T5 state.\n",
    "ofa_base_ckpt = torch.load('/data/mshukor/logs/ofa/pretrained_models/ofa_base.pt', map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['args', 'cfg', 'model', 'criterion', 'optimizer_history', 'task_state', 'extra_state', 'last_optimizer_state'])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ofa_base_ckpt.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show layer 0 of both models to line up T5 sub-module names with the OFA ones.\n",
    "model_t5.encoder.block[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ofa.encoder.layers[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load BART weights (keys already match OFA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fairseq BART-base checkpoint; mapped onto the CPU.\n",
    "weights_path = '/data/mshukor/logs/ofa/pretrained_models/bart.base/model.pt'\n",
    "bart_state = torch.load(weights_path, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# strict=True: every BART key must exist in OFA and vice versa.\n",
    "model_ofa.load_state_dict(bart_state['model'], strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise instead of dumping every key: total count plus the first few names.\n",
    "bart_keys = list(bart_state['model'])\n",
    "len(bart_keys), bart_keys[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0125,  0.0014, -0.0096,  ...,  0.0022,  0.1057,  0.0103],\n",
       "        [-0.0114, -0.0169, -0.0184,  ..., -0.0131, -0.0043, -0.0053],\n",
       "        [ 0.0842, -0.0389,  0.0096,  ...,  0.0583,  0.0082,  0.0357],\n",
       "        ...,\n",
       "        [-0.0032, -0.0313, -0.1026,  ...,  0.0138,  0.0056, -0.0023],\n",
       "        [ 0.0104, -0.0045,  0.0263,  ...,  0.0158,  0.0324, -0.0111],\n",
       "        [-0.0038, -0.0532, -0.0147,  ...,  0.0067,  0.0256,  0.0009]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Only a cell's last expression is displayed, so inspect a single tensor here.\n",
    "ofa_state['encoder.embed_tokens.weight']"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ofa",
   "language": "python",
   "name": "ofa"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
